{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "# Convolutional Neural Networks\n",
    "\n",
    "__Content Creator:__ Guangyu Robert Yang\n",
    "\n",
    "__Content Modified:__ Kai Chen\n",
    "\n",
    "In this notebook, we will cover training a simple convolutional neural network on the MNIST dataset, visualizing tuning of model neurons, computing representational similarity analysis (RSA) for different model layers\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "from torchvision import datasets\n",
    "import torchvision.transforms as transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## MNIST Dataset\n",
    "\n",
    "We will use the classical MNIST dataset because it allows rapid training on CPUs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# Download MNIST dataset\n",
    "path = Path('./')\n",
    "transform = transforms.Compose(\n",
    "    [transforms.ToTensor(),\n",
    "     transforms.Normalize((0.1307,), (0.3081,)),\n",
    "     ])\n",
    "trainset = datasets.MNIST(path, train=True, download=True,transform=transform)\n",
    "train_loader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True)\n",
    "\n",
    "testset = datasets.MNIST(path, train=False, transform=transform)\n",
    "test_loader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# Visualize an image from the dataset\n",
    "for i, data in enumerate(train_loader):\n",
    "    images, labels = data\n",
    "    break\n",
    "print('Shape of images from one batch : ', images.shape)\n",
    "print('Shape of labels from one batch : ', labels.shape)\n",
    "\n",
    "plt.imshow(images[0, 0])\n",
    "plt.title('Label {}'.format(labels[0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Define a convolutional neural network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        # 1 input image channel, 6 output channels, 3x3 square convolution\n",
    "        # kernel\n",
    "        self.conv1 = nn.Conv2d(1, 6, 3, padding=1)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 3, padding=1)\n",
    "        self.fc1 = nn.Linear(16 * 7 * 7, 120)  # 7*7 from image dimension\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "        # Set whether to readout activation\n",
    "        self.readout = False\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Max pooling over a (2, 2) window\n",
    "        l1 = F.max_pool2d(F.relu(self.conv1(x)), 2)\n",
    "        l2 = F.max_pool2d(F.relu(self.conv2(l1)), 2)\n",
    "        l2_flat = torch.flatten(l2, start_dim=1)  # flatten tensor, while keeping batch dimension\n",
    "        l3 = F.relu(self.fc1(l2_flat))\n",
    "        l4 = F.relu(self.fc2(l3))\n",
    "        y = self.fc3(l4)\n",
    "\n",
    "        if self.readout:\n",
    "            return {'l1': l1, 'l2': l2, 'l3': l3, 'l4': l4, 'y': y}\n",
    "        else:\n",
    "            return y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Training the network\n",
    "\n",
    "The following code trains the above network on MNIST until it reaches 95% accuracy. It should take only ~500-1000 training steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "\n",
    "# Instantiate the network and print information\n",
    "net = Net()\n",
    "print(net)\n",
    "\n",
    "# Use Adam optimizer\n",
    "optimizer = optim.Adam(net.parameters(), lr=0.01)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# Train for only one epoch\n",
    "running_loss = 0\n",
    "running_acc = 0\n",
    "for i, data in enumerate(train_loader):\n",
    "    image, label = data\n",
    "\n",
    "    # in your training loop:\n",
    "    optimizer.zero_grad()   # zero the gradient buffers\n",
    "    output = net(image)\n",
    "    loss = criterion(output, label)\n",
    "    loss.backward()\n",
    "    optimizer.step()    # Does the update\n",
    "\n",
    "    # prediction\n",
    "    prediction = torch.argmax(output, axis=-1)\n",
    "    acc = torch.mean((label == prediction).float())\n",
    "\n",
    "    running_loss += loss.item()\n",
    "    running_acc += acc\n",
    "    if i % 100 == 99:\n",
    "        running_loss /= 100\n",
    "        running_acc /= 100\n",
    "        print('Step {}, Loss {:0.4f}, Acc {:0.3f}'.format(\n",
    "            i+1, running_loss, running_acc))\n",
    "        if running_acc > 0.95:\n",
    "            break\n",
    "        running_loss, running_acc = 0, 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Compute representation similarity\n",
    "\n",
    "We will first compute the neural responses to a batch of images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "for i, data in enumerate(test_loader):\n",
    "    images, labels = data\n",
    "    break\n",
    "\n",
    "# Readout network activity\n",
    "net.readout = True\n",
    "activity = net(images)\n",
    "\n",
    "n_images = len(labels)\n",
    "\n",
    "ind = np.argsort(labels.numpy())\n",
    "images = images.detach().numpy()[ind]\n",
    "labels = labels.numpy()[ind]\n",
    "\n",
    "similarity = dict()\n",
    "for key, val in activity.items():\n",
    "    new_val = val.detach().numpy()[ind]\n",
    "    activity[key] = new_val\n",
    "    new_val = new_val.reshape(n_images, -1)\n",
    "    similarity[key] = np.corrcoef(new_val)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "Here we plot the representational dissimilarity matrix for neural responses of different layers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "for layer, sim in similarity.items():\n",
    "    plt.figure()\n",
    "    plt.imshow(1 - sim, vmin=0, vmax=2, cmap='bwr')\n",
    "    plt.colorbar()\n",
    "    plt.title('Layer ' + layer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "## Search for preferred stimulus for a given neuron\n",
    "\n",
    "We will optimize images (not weights) such as the activity of a chosen neuron is maximized"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# Freeze for parameters in the network\n",
    "for param in net.parameters():\n",
    "    param.requires_grad = False\n",
    "\n",
    "\n",
    "# Here syn_image is the variable to be optimized\n",
    "# Initialized randomly for search in parallel\n",
    "batch_size = 64\n",
    "image_size = [batch_size] + list(images.shape[1:])\n",
    "syn_image_init = np.random.rand(*image_size)\n",
    "syn_image = torch.tensor(syn_image_init, requires_grad=True, dtype=torch.float32)\n",
    "\n",
    "# Use Adam optimizer\n",
    "optimizer = optim.Adam([syn_image], lr=0.01)\n",
    "\n",
    "running_loss = 0\n",
    "for i in range(1000):\n",
    "    optimizer.zero_grad()   # zero the gradient buffers\n",
    "    syn_image.data.clamp_(min=0.0, max=1.0)\n",
    "    syn_image_transform = (syn_image - 0.1307) / 0.3081\n",
    "    act = net(syn_image_transform)\n",
    "\n",
    "    # Pick a neuron, and minimize its negative activity\n",
    "    neuron = act['l4'][:, 1]\n",
    "    if i == 0:\n",
    "        neuron_init = neuron.detach().numpy()\n",
    "\n",
    "    loss = -torch.mean(torch.square(neuron))\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    running_loss += loss.item()\n",
    "    if i % 100 == 99:\n",
    "        running_loss /= 100\n",
    "        print('Step {}, Loss {:0.4f}'.format(i+1, running_loss))\n",
    "        running_loss = 0\n",
    "\n",
    "neuron = neuron.detach().numpy()\n",
    "syn_image = syn_image.detach().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "Plot the neural activity driven by a batch of images before and after the optimization process. After optimization, many images can drive the unit strongly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "bins = np.linspace(0, np.max(neuron), 30)\n",
    "plt.figure()\n",
    "plt.hist(neuron, bins, label='After', alpha=0.5)\n",
    "plt.hist(neuron_init, bins, label='Before', alpha=0.5)\n",
    "plt.xlabel('Neural activity')\n",
    "plt.ylabel('Number of images searched')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "Visualize the image that most strongly activates the chosen unit, before and after optimization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "ind = np.argsort(neuron)\n",
    "plt.figure()\n",
    "plt.title('Before')\n",
    "plt.imshow(syn_image_init[ind[-1], 0], vmin=0, vmax=1)\n",
    "plt.colorbar()\n",
    "\n",
    "plt.figure()\n",
    "plt.title('After')\n",
    "plt.imshow(syn_image[ind[-1], 0], vmin=0, vmax=1)\n",
    "plt.colorbar()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {}
   },
   "source": [
    "# Supplementary materials\n",
    "\n",
    "Code for making publication quality figures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "import matplotlib as mpl\n",
    "from pathlib import Path\n",
    "\n",
    "mpl.rcParams['font.size'] = 7\n",
    "mpl.rcParams['pdf.fonttype'] = 42\n",
    "mpl.rcParams['ps.fonttype'] = 42\n",
    "mpl.rcParams['font.family'] = 'arial'\n",
    "\n",
    "# pick just several images per class\n",
    "ind_im = np.concatenate([np.where(labels == l)[0][:5] for l in range(10)])\n",
    "\n",
    "ticks = [np.mean(np.where(labels[ind_im] == l)) for l in range(10)]\n",
    "\n",
    "for layer, sim in similarity.items():\n",
    "    dissim = 1 - sim[ind_im, :][:, ind_im]\n",
    "    # Convert to percentile\n",
    "    dissim_flat = dissim.flatten()\n",
    "    dissim_flat_new = np.zeros(len(dissim_flat))\n",
    "    for i, ind in enumerate(np.argsort(dissim_flat)):\n",
    "        dissim_flat_new[ind] = i\n",
    "\n",
    "    dissim_flat = dissim_flat_new / len(dissim_flat) * 100\n",
    "    dissim = dissim_flat.reshape(dissim.shape)\n",
    "\n",
    "    fig = plt.figure(figsize=(3, 3))\n",
    "    ax = fig.add_axes([.1, .1, .8, .8])\n",
    "    im = ax.imshow(dissim, cmap='cool', vmin=0, vmax=100, aspect=1,\n",
    "                   extent=(-0.5, dissim.shape[0]-0.5, dissim.shape[1]-0.5, -0.5),\n",
    "                   interpolation='nearest'\n",
    "                  )\n",
    "    if layer != 'y':\n",
    "        title = 'Layer ' + layer[-1]\n",
    "    else:\n",
    "        title = 'Output layer'\n",
    "    plt.title(title)\n",
    "    for loc in ['left', 'right', 'top', 'bottom']:\n",
    "        ax.spines[loc].set_visible(False)\n",
    "\n",
    "    ax.set_xticks(ticks)\n",
    "    ax.set_xticklabels([str(i) for i in range(10)])\n",
    "    ax.set_yticks([-0.5] + ticks + [dissim.shape[0]-0.5])\n",
    "    ax.set_yticklabels([''] + [str(i) for i in range(10)] + [''])\n",
    "    ax.tick_params(length=0)\n",
    "\n",
    "    divider = make_axes_locatable(ax)\n",
    "    cax = divider.append_axes(\"right\", size=\"5%\", pad=0.05)\n",
    "    cb = plt.colorbar(im, cax=cax, ticks=[0, 100])\n",
    "    cb.set_label('Dissimilarity', labelpad=-7)\n",
    "    cb.outline.set_linewidth(0)\n",
    "\n",
    "    fname = Path('RSA_' + layer)\n",
    "    fig.savefig(fname.with_suffix('.pdf'), transparent=True)\n",
    "    fig.savefig(fname.with_suffix('.png'), dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "# Plot network activity\n",
    "i_image = 0\n",
    "\n",
    "fig = plt.figure(figsize=(1.5, 1.5))\n",
    "plt.imshow(images[i_image, 0])\n",
    "fname = Path('sample_cnn_image')\n",
    "plt.axis('off')\n",
    "fig.savefig(fname.with_suffix('.pdf'), transparent=True)\n",
    "fig.savefig(fname.with_suffix('.png'), dpi=300)\n",
    "\n",
    "layers = ['l1', 'l2', 'l3', 'l4', 'y']\n",
    "\n",
    "for layer in layers:\n",
    "    act = activity[layer]\n",
    "    act = act[i_image]\n",
    "    if len(act.shape) == 3:\n",
    "        n_channels = act.shape[0]\n",
    "        if n_channels == 6:\n",
    "            n_x, n_y = 2, 3\n",
    "        elif n_channels == 16:\n",
    "            n_x, n_y = 4, 4\n",
    "        else:\n",
    "            n_x, n_y = n_channels, 1\n",
    "        vmax = np.max(act)\n",
    "        fig, axs = plt.subplots(n_y, n_x, figsize=(1.5/n_y*n_x, 1.5))\n",
    "        for i_channel in range(n_channels):\n",
    "            ax = axs[np.mod(i_channel, n_y), i_channel//n_y]\n",
    "            ax.imshow(act[i_channel], vmin=0, vmax=vmax)\n",
    "            ax.set_axis_off()\n",
    "        plt.tight_layout()\n",
    "    elif len(act.shape) == 1:\n",
    "        fig = plt.figure(figsize=(0.5, 1.5))\n",
    "        plt.imshow(act[:, np.newaxis], aspect='auto')\n",
    "        plt.axis('off')\n",
    "\n",
    "    fname = Path('sample_cnn_activity_' + layer)\n",
    "    fig.savefig(fname.with_suffix('.pdf'), transparent=True)\n",
    "    fig.savefig(fname.with_suffix('.png'), dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "def get_syn_image(layer):\n",
    "    # Here syn_image is the variable to be optimized\n",
    "    # Initialized randomly for search in parallel\n",
    "    batch_size = 64\n",
    "    image_size = [batch_size] + list(images.shape[1:])\n",
    "    syn_image_init = np.random.rand(*image_size)\n",
    "    syn_image = torch.tensor(syn_image_init, requires_grad=True, dtype=torch.float32)\n",
    "\n",
    "    # Use Adam optimizer\n",
    "    optimizer = optim.Adam([syn_image], lr=0.01)\n",
    "\n",
    "    running_loss = 0\n",
    "    running_loss_reg = 0\n",
    "    for i in range(1000):\n",
    "        optimizer.zero_grad()   # zero the gradient buffers\n",
    "        syn_image.data.clamp_(min=0.0, max=1.0)\n",
    "        syn_image_transform = (syn_image - 0.1307) / 0.3081\n",
    "        activity = net(syn_image_transform)\n",
    "\n",
    "        # Pick a neuron, and minimize its negative activity\n",
    "        neuron = activity[layer]\n",
    "\n",
    "        # Choose a neuron that is already most activated\n",
    "        if i == 0:\n",
    "            neuron_avg = np.mean(neuron.detach().numpy(), axis=0)\n",
    "            ind = np.argsort(neuron_avg.flatten())[-1]\n",
    "            print('Chosen unit', ind)\n",
    "\n",
    "        neuron = neuron.view(batch_size, -1)[:, ind]\n",
    "        if i == 0:\n",
    "            neuron_init = neuron.detach().numpy()\n",
    "\n",
    "        loss = -torch.mean(torch.square(neuron))\n",
    "        loss_reg = torch.mean(torch.square(syn_image_transform)) * 100\n",
    "        loss += loss_reg\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        running_loss += loss.item()\n",
    "        running_loss_reg += loss_reg.item()\n",
    "        if i % 100 == 99:\n",
    "            running_loss /= 100\n",
    "            running_loss_reg /= 100\n",
    "            print('Step {}, Loss {:0.4f} Loss Regularization {:0.4f}'.format(\n",
    "                i+1, running_loss, running_loss_reg))\n",
    "            running_loss, running_loss_reg = 0, 0\n",
    "\n",
    "    neuron = neuron.detach().numpy()\n",
    "    syn_image = syn_image.detach().numpy()\n",
    "    return syn_image, syn_image_init, neuron\n",
    "\n",
    "layers = ['l1', 'l2', 'l3', 'l4', 'y']\n",
    "results = dict()\n",
    "for layer in layers:\n",
    "    print('Layer ', layer)\n",
    "    syn_image, syn_image_init, neuron = get_syn_image(layer)\n",
    "    results[layer] = {'syn_image': syn_image,\n",
    "                      'syn_image_init': syn_image_init,\n",
    "                      'neuron': neuron}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "execution": {}
   },
   "outputs": [],
   "source": [
    "for layer in layers:\n",
    "    res = results[layer]\n",
    "    ind = np.argsort(res['neuron'])\n",
    "\n",
    "    fig = plt.figure(figsize=(0.8, 0.8))\n",
    "    ax = fig.add_axes([0.05, 0.05, 0.9, 0.9])\n",
    "    ax.imshow(res['syn_image'][ind[-1], 0], vmin=0, vmax=1)\n",
    "    plt.axis('off')\n",
    "\n",
    "    fname = Path('cnn_tuning_' + layer)\n",
    "    fig.savefig(fname.with_suffix('.pdf'), transparent=True)\n",
    "    fig.savefig(fname.with_suffix('.png'), dpi=300)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "W5_Tutorial1",
   "toc_visible": true
  },
  "kernel": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
